/*
 *  Copyright (c) 2004-present, Facebook, Inc.
 *  All rights reserved.
 *
 *  This source code is licensed under the BSD-style license found in the
 *  LICENSE file in the root directory of this source tree. An additional grant
 *  of patent rights can be found in the PATENTS file in the same directory.
 *
 */
#include "fboss/agent/FabricLinkMonitoringManager.h"

#include <fb303/ServiceData.h>
#include <folly/Memory.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <folly/logging/xlog.h>
#include <gtest/gtest.h>
#include "fboss/agent/FbossError.h"
#include "fboss/agent/RxPacket.h"
#include "fboss/agent/SwSwitch.h"
#include "fboss/agent/TxPacket.h"
#include "fboss/agent/hw/mock/MockHwSwitch.h"
#include "fboss/agent/hw/mock/MockPlatform.h"
#include "fboss/agent/hw/mock/MockRxPacket.h"
#include "fboss/agent/state/Port.h"
#include "fboss/agent/test/HwTestHandle.h"
#include "fboss/agent/test/TestUtils.h"
#include "gmock/gmock.h"

using ::testing::_;
using ::testing::AtLeast;

using namespace facebook::fboss;
using folly::MacAddress;
using folly::io::Cursor;
using folly::io::RWPrivateCursor;
using std::make_shared;
using std::shared_ptr;
using std::unique_ptr;

namespace {

constexpr size_t kFabricLinkMonitoringPacketSize = 480;

unique_ptr<HwTestHandle> setupTestHandle() {
  auto state = testStateAWithPortsUp();
  addSwitchInfo(
      state,
      cfg::SwitchType::FABRIC,
      0,
      cfg::AsicType::ASIC_TYPE_MOCK,
      cfg::switch_config_constants::DEFAULT_PORT_ID_RANGE_MIN(),
      cfg::switch_config_constants::DEFAULT_PORT_ID_RANGE_MAX(),
      0,
      std::nullopt,
      std::nullopt,
      MockPlatform::getMockLocalMac().toString());

  for (auto& portMap : std::as_const(*state->getPorts())) {
    for (auto& port : std::as_const(*portMap.second)) {
      auto newPort = port.second->modify(&state);
      newPort->setPortType(cfg::PortType::FABRIC_PORT);
    }
  }

  return createTestHandle(state);
}

TxMatchFn checkFabricMonitoringPacket() {
  return [=](const TxPacket* pkt) {
    const auto* buf = pkt->buf();
    const auto chainlen = buf->computeChainDataLength();

    EXPECT_EQ(chainlen, kFabricLinkMonitoringPacketSize);

    Cursor c(buf);

    auto sequenceNumber = c.readBE<uint64_t>();
    EXPECT_GE(sequenceNumber, 0);

    auto portId = c.readBE<uint32_t>();
    EXPECT_GE(portId, 0);

    auto payloadPattern =
        FabricLinkMonitoringManager::getPayloadPattern(sequenceNumber);
    size_t payloadSize =
        kFabricLinkMonitoringPacketSize - sizeof(uint64_t) - sizeof(uint32_t);
    for (size_t i = 0; i < payloadSize; i += 4) {
      auto value = c.readBE<uint32_t>();
      EXPECT_EQ(value, payloadPattern);
    }
  };
}

std::unique_ptr<RxPacket> createFabricMonitoringRxPacket(
    SwSwitch* /*sw*/,
    const PortID& portId,
    uint64_t sequenceNumber) {
  std::vector<uint8_t> data(kFabricLinkMonitoringPacketSize);

  // Serialize sequenceNumber (uint64_t) in big-endian order
  uint64_t seq_be = htobe64(sequenceNumber);
  std::memcpy(data.data(), &seq_be, sizeof(seq_be));

  // Serialize portId (uint32_t) in big-endian order
  uint32_t port_be = htobe32(static_cast<uint32_t>(portId));
  std::memcpy(data.data() + 8, &port_be, sizeof(port_be));

  // Fill the rest of the buffer with pattern
  size_t fillStart = sizeof(uint64_t) + sizeof(uint32_t);
  ;
  auto fillPattern =
      FabricLinkMonitoringManager::getPayloadPattern(sequenceNumber);
  for (size_t i = fillStart; i + 4 <= kFabricLinkMonitoringPacketSize; i += 4) {
    std::memcpy(data.data() + i, &fillPattern, sizeof(fillPattern));
  }

  auto buf = folly::IOBuf::copyBuffer(data.data(), data.size());
  auto rxPkt = std::make_unique<MockRxPacket>(std::move(buf));
  rxPkt->setSrcPort(portId);
  return rxPkt;
}

} // namespace

TEST(FabricLinkMonitoringManagerTest, CreateManager) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  EXPECT_NE(manager, nullptr);
}

TEST(FabricLinkMonitoringManagerTest, SendPacketsOnFabricPorts) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(AtLeast(1));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, PacketPayloadValidation) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);

  PortID testPort(1);

  // Get initial stats
  auto statsBefore = manager->getFabricLinkMonPortStats(testPort);

  auto rxPkt = createFabricMonitoringRxPacket(sw, testPort, 0);
  Cursor cursor(rxPkt->buf());

  EXPECT_NO_THROW(manager->handlePacket(std::move(rxPkt), std::move(cursor)));

  // Verify that rxCount was incremented, confirming good payload was processed
  auto statsAfter = manager->getFabricLinkMonPortStats(testPort);
  EXPECT_EQ(*statsAfter.rxCount(), *statsBefore.rxCount() + 1);
  EXPECT_EQ(
      *statsAfter.invalidPayloadCount(), *statsBefore.invalidPayloadCount());

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, PacketSizeValidation) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto pkt = sw->allocatePacket(kFabricLinkMonitoringPacketSize);
  EXPECT_EQ(
      pkt->buf()->computeChainDataLength(), kFabricLinkMonitoringPacketSize);
}

TEST(FabricLinkMonitoringManagerTest, MultipleStartStopManager) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);

  for (int i = 0; i < 10; ++i) {
    EXPECT_NO_THROW(manager->start());
    waitForBackgroundThread(sw);
    EXPECT_NO_THROW(manager->stop());
  }
}

TEST(FabricLinkMonitoringManagerTest, PacketSequenceNumberIncrement) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  uint64_t lastSequenceNumber = 0;
  int packetCount = 0;

  auto getPortIdToTrack = [&]() {
    auto state = sw->getState();
    for (auto& portMap : std::as_const(*state->getPorts())) {
      for (auto& port : std::as_const(*portMap.second)) {
        if (port.second->getPortType() == cfg::PortType::FABRIC_PORT) {
          return static_cast<int>(port.second->getID());
        }
      }
    }
    return -1;
  };

  auto portIdToTrack = getPortIdToTrack();
  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet",
              [&lastSequenceNumber, &packetCount, &portIdToTrack](
                  const TxPacket* pkt) {
                Cursor c(pkt->buf());
                auto seqNum = c.readBE<uint64_t>();
                auto portId = c.readBE<uint32_t>();
                if (portId == portIdToTrack) {
                  if (packetCount > 0) {
                    EXPECT_GT(seqNum, lastSequenceNumber);
                  }
                  lastSequenceNumber = seqNum;
                  packetCount++;
                }
              }),
          _,
          _))
      .Times(AtLeast(5));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, OnlyFabricPortsReceivePackets) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto state = sw->getState();
  auto newState = state->clone();

  for (auto& portMap : std::as_const(*newState->getPorts())) {
    for (auto& port : std::as_const(*portMap.second)) {
      auto newPort = port.second->modify(&newState);
      if (port.second->getID() == PortID(1)) {
        newPort->setPortType(cfg::PortType::INTERFACE_PORT);
      } else {
        newPort->setPortType(cfg::PortType::FABRIC_PORT);
      }
    }
  }

  sw->updateStateBlocking(
      "Update port types", [&](const auto&) { return newState; });

  int fabricPortCount = 0;
  for (const auto& portMap : std::as_const(*sw->getState()->getPorts())) {
    for (const auto& [portId, port] : std::as_const(*portMap.second)) {
      (void)portId;
      if (port->getPortType() == cfg::PortType::FABRIC_PORT &&
          port->isPortUp()) {
        fabricPortCount++;
      }
    }
  }

  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(AtLeast(fabricPortCount));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, PacketDropDetectionOnMaxPendingExceeded) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(AtLeast(4));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  // Wait for multiple packets to be sent
  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, MultipleSequentialPackets) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(AtLeast(5));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  // Send multiple rounds of packets
  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);

  // Verify packets were sent
  PortID testPort(1);
  auto statsAfter = manager->getFabricLinkMonPortStats(testPort);
  EXPECT_GT(*statsAfter.txCount(), 0);
  // Since no packets are being received, pending sequence numbers should
  // accumulate
  EXPECT_GT(manager->getPendingSequenceNumbers(testPort).size(), 0);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, DetectInvalidPayload) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);

  // Create a packet with invalid payload
  PortID portId = PortID(1);

  // Get initial stats
  auto statsBefore = manager->getFabricLinkMonPortStats(portId);
  std::vector<uint8_t> data(kFabricLinkMonitoringPacketSize);

  // Serialize sequenceNumber
  uint64_t seq_be = htobe64(0);
  std::memcpy(data.data(), &seq_be, sizeof(seq_be));

  // Serialize portId
  uint32_t port_be = htobe32(static_cast<uint32_t>(portId));
  std::memcpy(data.data() + 8, &port_be, sizeof(port_be));

  // Fill with incorrect pattern
  size_t fillStart = sizeof(uint64_t) + sizeof(uint32_t);
  uint32_t wrongPattern = 0xDEADBEEF; // Wrong pattern
  for (size_t i = fillStart; i + 4 <= kFabricLinkMonitoringPacketSize; i += 4) {
    std::memcpy(data.data() + i, &wrongPattern, sizeof(wrongPattern));
  }

  auto buf = folly::IOBuf::copyBuffer(data.data(), data.size());
  auto rxPkt = std::make_unique<MockRxPacket>(std::move(buf));
  rxPkt->setSrcPort(portId);

  folly::io::Cursor cursor(rxPkt->buf());
  manager->handlePacket(std::move(rxPkt), cursor);

  // Verify invalidPayloadCount was incremented
  auto statsAfter = manager->getFabricLinkMonPortStats(portId);
  EXPECT_EQ(
      *statsAfter.invalidPayloadCount(),
      *statsBefore.invalidPayloadCount() + 1);
  EXPECT_EQ(*statsAfter.rxCount(), *statsBefore.rxCount());

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, DetectPortIdMismatch) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);

  // Create a packet with mismatched port ID
  PortID actualPort = PortID(1);

  // Get initial stats
  auto statsBefore = manager->getFabricLinkMonPortStats(actualPort);
  PortID wrongPortInPacket = PortID(99);
  std::vector<uint8_t> data(kFabricLinkMonitoringPacketSize);

  // Serialize sequenceNumber
  uint64_t seq_be = htobe64(0);
  std::memcpy(data.data(), &seq_be, sizeof(seq_be));

  // Serialize wrong portId
  uint32_t port_be = htobe32(static_cast<uint32_t>(wrongPortInPacket));
  std::memcpy(data.data() + 8, &port_be, sizeof(port_be));

  // Fill with correct pattern
  size_t fillStart = sizeof(uint64_t) + sizeof(uint32_t);
  auto fillPattern = FabricLinkMonitoringManager::getPayloadPattern(0);
  for (size_t i = fillStart; i + 4 <= kFabricLinkMonitoringPacketSize; i += 4) {
    std::memcpy(data.data() + i, &fillPattern, sizeof(fillPattern));
  }

  auto buf = folly::IOBuf::copyBuffer(data.data(), data.size());
  auto rxPkt = std::make_unique<MockRxPacket>(std::move(buf));
  rxPkt->setSrcPort(actualPort);

  folly::io::Cursor cursor(rxPkt->buf());
  manager->handlePacket(std::move(rxPkt), cursor);

  // Verify invalidPayloadCount was incremented (port ID mismatch)
  auto statsAfter = manager->getFabricLinkMonPortStats(actualPort);
  EXPECT_EQ(
      *statsAfter.invalidPayloadCount(),
      *statsBefore.invalidPayloadCount() + 1);
  EXPECT_EQ(*statsAfter.rxCount(), *statsBefore.rxCount());

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, OutOfOrderPacketHandling) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(AtLeast(5));

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  // Wait for multiple rounds of packets to build up pending sequence numbers
  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);
  waitForBackgroundThread(sw);

  // Stop the manager to avoid race condition where background thread
  // sends another packet while we're reading stats
  manager->stop();

  PortID portId = PortID(1);
  auto statsBefore = manager->getFabricLinkMonPortStats(portId);

  // Verify packets were sent and are pending
  EXPECT_GT(*statsBefore.txCount(), 0);

  // Get the pending sequence numbers
  auto pendingSeqNums = manager->getPendingSequenceNumbers(portId);
  EXPECT_GT(pendingSeqNums.size(), 0);
  ASSERT_FALSE(pendingSeqNums.empty());

  // Find the last (highest) sequence number in the pending list
  uint64_t lastSeqNum =
      *std::max_element(pendingSeqNums.begin(), pendingSeqNums.end());

  // Count how many sequence numbers are before the last one
  size_t expectedDropCount = 0;
  for (auto seqNum : pendingSeqNums) {
    if (seqNum < lastSeqNum) {
      expectedDropCount++;
    }
  }

  // Receive the last packet in the pending list
  // This should cause all earlier pending packets to be marked as dropped
  auto rxPkt = createFabricMonitoringRxPacket(sw, portId, lastSeqNum);
  folly::io::Cursor cursor(rxPkt->buf());
  manager->handlePacket(std::move(rxPkt), cursor);

  auto statsAfter = manager->getFabricLinkMonPortStats(portId);
  EXPECT_EQ(*statsAfter.rxCount(), *statsBefore.rxCount() + 1);
  // All packets with sequence numbers < lastSeqNum should now be dropped
  EXPECT_EQ(
      *statsAfter.droppedCount(),
      *statsBefore.droppedCount() + expectedDropCount);
}

TEST(FabricLinkMonitoringManagerTest, GetPayloadPatternAlternates) {
  // Verify payload pattern alternates between odd and even sequence numbers
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(0), 0x5A5A5A5A);
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(1), 0xA5A5A5A5);
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(2), 0x5A5A5A5A);
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(3), 0xA5A5A5A5);
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(100), 0x5A5A5A5A);
  EXPECT_EQ(FabricLinkMonitoringManager::getPayloadPattern(101), 0xA5A5A5A5);
}

TEST(FabricLinkMonitoringManagerTest, PortsDownNotSent) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  // Bring all ports down
  auto state = sw->getState();
  auto newState = state->clone();
  for (auto& portMap : std::as_const(*newState->getPorts())) {
    for (auto& port : std::as_const(*portMap.second)) {
      auto newPort = port.second->modify(&newState);
      newPort->setOperState(false);
    }
  }
  sw->updateStateBlocking(
      "bring ports down", [newState](const auto&) { return newState; });

  // Should not send any packets when ports are down
  EXPECT_HW_CALL(
      sw,
      sendPacketOutOfPortSyncForPktType_(
          TxPacketMatcher::createMatcher(
              "Fabric Monitoring Packet", checkFabricMonitoringPacket()),
          _,
          _))
      .Times(0);

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);

  manager->stop();
}

TEST(FabricLinkMonitoringManagerTest, HandlePacketWithNoPendingSequence) {
  auto handle = setupTestHandle();
  auto sw = handle->getSw();

  auto manager = std::make_unique<FabricLinkMonitoringManager>(sw);
  manager->start();

  waitForStateUpdates(sw);
  waitForBackgroundThread(sw);

  // Try to receive a packet when no packets were sent
  PortID portId = PortID(1);
  auto rxPkt = createFabricMonitoringRxPacket(sw, portId, 999);
  folly::io::Cursor cursor(rxPkt->buf());

  // handlePacket should not throw even when receiving unexpected packet
  EXPECT_NO_THROW(manager->handlePacket(std::move(rxPkt), cursor));

  // Verify noPendingSeqNumCount was incremented
  auto statsAfter = manager->getFabricLinkMonPortStats(portId);
  EXPECT_GT(*statsAfter.noPendingSeqNumCount(), 0);

  manager->stop();
}
